#! /bin/bash

# Abort on any expansion of an unset variable.
set -o nounset

# Built-in defaults; the option parser below overrides them (-o, -d, -e).
OUTPUT_FILE='/run/prometheus/dcgm.prom'
COLLECT_INTERVAL_MS='1000'
HOST_ENGINE='no'

# Print the one-line command synopsis on stderr.
usage() {
    >&2 printf '%s\n' "usage: $0 [-h] [-e] [-o output_file] [-d collect_interval_ms (>=100)]"
}

# Parse the command line with getopt(1):
#   -h         print usage and exit
#   -e         start/stop the NVIDIA host engine around the collection
#   -o <file>  metrics output file
#   -d <ms>    collection interval in milliseconds
# The status of getopt is tested directly instead of through $?.
if ! options=$(getopt -o heo:d: -- "$@"); then
    usage
    exit 1
fi
# Replace the positional parameters with the argument list printed by getopt;
# that list ends with a "--" marker.
eval set -- "${options}"

while true; do
    case "$1" in
        # Exit unconditionally: with "usage && exit 1" a failed write in
        # usage (e.g. closed stderr) skipped the exit and, since nothing is
        # shifted here, looped forever.
        -h) usage; exit 1;;
        -e) HOST_ENGINE=yes; shift;;
        -o) OUTPUT_FILE=$2; shift 2;;
        -d) COLLECT_INTERVAL_MS=$2; shift 2;;
        # End-of-options marker: consume it so only operands remain in "$@".
        --) shift; break;;
        *)  break;;
    esac
done

# Validate the collection interval: it must be a plain decimal integer of at
# least 100 (milliseconds). The digits-only check comes first because a
# non-numeric value makes the numeric test itself error out (status 2), which
# a bare "if [ ... -lt 100 ]" treats as "valid" and lets through. Any error
# from the range test (e.g. a value too large to parse) also counts as invalid.
if [[ ! "${COLLECT_INTERVAL_MS}" =~ ^[0-9]+$ ]] ||
   ! [ "${COLLECT_INTERVAL_MS}" -ge 100 ] 2> /dev/null; then
    usage
    exit 1
fi

# Create the directory that will hold the metrics file. Everything is quoted
# (and "--" ends option parsing) so paths containing spaces, glob characters
# or a leading dash are handled as a single operand.
if ! mkdir -p -- "$(dirname -- "${OUTPUT_FILE}")"; then
    echo "Could not create output directory for ${OUTPUT_FILE}" >&2
    exit 1
fi

# Catch termination signals instead of dying on them, so that an interrupted
# "wait" further down returns and the script still reaches its cleanup code.
trap 'echo "Caught signal, terminating..."' HUP INT QUIT PIPE TERM

# Launch the NVIDIA host engine ourselves when requested (-e); its stderr is
# discarded.
case "${HOST_ENGINE}" in
    yes)
        printf '%s\n' "Starting NVIDIA host engine..."
        nv-hostengine 2> /dev/null
        ;;
esac

printf '%s\n' "Collecting metrics at ${OUTPUT_FILE} every ${COLLECT_INTERVAL_MS}ms..."

# Background collector: "dcgmi dmon" prints one row per GPU for the requested
# field ids every COLLECT_INTERVAL_MS; awk turns each row into metric lines,
# accumulates them in "<output>.swp" and renames that file over the real
# output once the last GPU of a sample has been processed.
dcgmi dmon -d "${COLLECT_INTERVAL_MS}" -e \
"54,"\
"100,101,"\
"140,150,155,156,"\
"200,201,202,203,204,206,207,"\
"230,240,241,242,243,244,245,246,"\
"251,252,"\
"310,311,312,313,"\
"390,391,392,"\
"409,419,429,439" | \
awk -v "out=${OUTPUT_FILE}" -v "ngpus=$(nvidia-smi -L | wc -l)" '

# NOTE: this program lives inside a single-quoted shell string, so it must
# never contain a literal single-quote character; "\047" is used instead.

BEGIN {
    # Staging file. Keeping the name in a variable avoids writing an
    # unparenthesized concatenation after ">", which awk implementations
    # do not all parse the same way.
    swp = out ".swp"
}

# Return s wrapped in single quotes for use as one shell word; embedded
# single quotes are rewritten as: quote, backslash, quote, quote.
function shquote(s) {
    gsub("\047", "\047\\\047\047", s)
    return "\047" s "\047"
}

# Append one sample for the current row (globals gpu and uuid) to the staging
# file. Values containing "N/A" are skipped. The HELP/TYPE header lines are
# written only while processing GPU 0, i.e. once per metric per sample.
function metric(name, type, help, value) {
    if (value !~ "N/A") {
        if (gpu == 0) {
            printf "# HELP dcgm_%s %s\n", name, help > swp
            printf "# TYPE dcgm_%s %s\n", name, type > swp
        }
        printf "dcgm_%s{gpu=\"%s\",uuid=\"%s\"} %s\n", name, gpu, uuid, value > swp
    }
}

# Data rows only: skip empty lines, the first two lines of output, and any
# line whose first field starts with "#" or "Id" (presumably header rows).
(NF && NR > 2 && !($1 ~ "^#" || $1 ~ "^Id")) {
    # Labels
    i = 1
    gpu = $(i++)                                                                                                      # field 0 (implicit)
    uuid = $(i++)                                                                                                     # field 54

    # Clocks
    metric("sm_clock", "gauge", "SM clock frequency (in MHz).", $(i++))                                               # field 100
    metric("memory_clock", "gauge", "Memory clock frequency (in MHz).", $(i++))                                       # field 101

    # Temperature
    metric("memory_temp", "gauge", "Memory temperature (in C).", $(i++))                                              # field 140
    metric("gpu_temp", "gauge", "GPU temperature (in C).", $(i++))                                                    # field 150

    # Power
    metric("power_usage", "gauge", "Power draw (in W).", $(i++))                                                      # field 155
    metric("total_energy_consumption", "counter", "Total energy consumption since boot (in mJ).", $(i++))             # field 156 FIXED

    # PCIe
    metric("pcie_tx_throughput", "counter", "Total number of bytes transmitted through PCIe TX (in KB)", $(i++))      # field 200 FIXED
    metric("pcie_rx_throughput", "counter", "Total number of bytes received through PCIe RX (in KB)", $(i++))         # field 201 FIXED
    metric("pcie_replay_counter", "counter", "Total number of PCIe retries.", $(i++))                                 # field 202

    # Utilization (the sample period varies depending on the product)
    metric("gpu_utilization", "gauge", "GPU utilization (in %).", $(i++))                                             # field 203
    metric("mem_copy_utilization", "gauge", "Memory utilization (in %).", $(i++))                                     # field 204
    metric("enc_utilization", "gauge", "Encoder utilization (in %).", $(i++))                                         # field 206
    metric("dec_utilization", "gauge", "Decoder utilization (in %).", $(i++))                                         # field 207

    # Errors and violations
    metric("xid_errors", "gauge", "Value of the last XID error encountered.", $(i++))                                 # field 230
    metric("power_violation", "counter", "Throttling duration due to power constraints (in us).", $(i++))             # field 240
    metric("thermal_violation", "counter", "Throttling duration due to thermal constraints (in us).", $(i++))         # field 241
    metric("sync_boost_violation", "counter", "Throttling duration due to sync-boost constraints (in us).", $(i++))   # field 242
    metric("board_limit_violation", "counter", "Throttling duration due to board limit constraints (in us).", $(i++)) # field 243 FIXME
    metric("low_util_violation", "counter", "Throttling duration due to low utilization (in us).", $(i++))            # field 244
    metric("reliability_violation", "counter", "Throttling duration due to reliability constraints (in us).", $(i++)) # field 245 FIXME
    metric("app_clock_violation", "counter", "Total throttling duration (in us).", $(i++))                            # field 246

    # Memory usage
    metric("fb_free", "gauge", "Framebuffer memory free (in MiB).", $(i++))                                           # field 251
    metric("fb_used", "gauge", "Framebuffer memory used (in MiB).", $(i++))                                           # field 252

    # ECC
    metric("ecc_sbe_volatile_total", "counter", "Total number of single-bit volatile ECC errors.", $(i++))            # field 310
    metric("ecc_dbe_volatile_total", "counter", "Total number of double-bit volatile ECC errors.", $(i++))            # field 311
    metric("ecc_sbe_aggregate_total", "counter", "Total number of single-bit persistent ECC errors.", $(i++))         # field 312
    metric("ecc_dbe_aggregate_total", "counter", "Total number of double-bit persistent ECC errors.", $(i++))         # field 313

    # Retired pages
    metric("retired_pages_sbe", "counter", "Total number of retired pages due to single-bit errors.", $(i++))         # field 390
    metric("retired_pages_dbe", "counter", "Total number of retired pages due to double-bit errors.", $(i++))         # field 391
    metric("retired_pages_pending", "counter", "Total number of pages pending retirement.", $(i++))                   # field 392

    # NVLink
    metric("nvlink_flit_crc_error_count_total", "counter", "Total number of NVLink flow-control CRC errors.", $(i++)) # field 409
    metric("nvlink_data_crc_error_count_total", "counter", "Total number of NVLink data CRC errors.", $(i++))         # field 419
    metric("nvlink_replay_error_count_total", "counter", "Total number of NVLink retries.", $(i++))                   # field 429
    metric("nvlink_recovery_error_count_total", "counter", "Total number of NVLink recovery errors.", $(i++))         # field 439
    #metric("nvlink_bandwidth_total", "counter", "Total number of NVLink bandwidth counters for all lanes", $(i++))   # field 449 TODO

    # Flush output file and move it for atomicity. This fires on the row whose
    # GPU id is ngpus - 1 (assumes ids run 0..ngpus-1 within each sample).
    # Both paths are shell-quoted so spaces or metacharacters in the output
    # file name cannot break or hijack the mv command.
    if (gpu == ngpus - 1) {
        close(swp)
        system("mv -- " shquote(swp) " " shquote(out))
    }
}' &

# PID of the last process of the background collector pipeline started above.
collector_pid=$!

# Block until the collector exits. A trapped signal makes wait return early
# with a status above 128.
wait "${collector_pid}"
wait_status=$?

# After an interrupted wait the collector may still be running; stop it so it
# does not outlive this script. Errors are ignored because it may be gone.
if [ "${wait_status}" -gt 128 ]; then
    kill "${collector_pid}" 2> /dev/null
fi

if [ "${HOST_ENGINE}" = "yes" ]; then
    echo "Stopping NVIDIA host engine..."
    nv-hostengine --term

    if [ -f /run/nvhostengine.pid ]; then
        pid=$(< /run/nvhostengine.pid)

        # The engine may already have exited in response to --term, so a
        # failure to deliver the signal is not reported.
        kill -SIGTERM "${pid}" 2> /dev/null

        # Poll for up to ~10 seconds (100 probes, 0.1 s apart).
        for (( attempt = 0; attempt < 100; attempt++ )); do
            kill -0 "${pid}" 2> /dev/null || break
            sleep 0.1
        done

        # Decide by probing the process once more instead of inspecting the
        # loop counter: the counter also reaches its final value when the
        # engine exits on the very last probe, which used to be misreported
        # as a failure.
        if kill -0 "${pid}" 2> /dev/null; then
            echo "Could not stop NVIDIA host engine" >&2
            kill -9 "${pid}" 2> /dev/null
            exit 1
        fi
        rm -f /run/nvhostengine.pid
    fi
fi

echo "Done"
exit 0
